{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Robot Localization, in JAX\n",
    "\n",
    "This notebook runs the robot localization benchmark in JAX. It only computes the linearization (residuals and Jacobians) and does not implement the optimization loop. A large number of linearizations are computed together as a batch.\n",
    "\n",
    "The experiment can be run on either CPU or GPU; the platform is selected in the configuration cell below.\n",
    "\n",
    "See [the paper](https://symforce.org/paper) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as onp\n",
    "import itertools\n",
    "import jaxlie\n",
    "from jax import numpy as np\n",
    "import jax\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin JAX's default platform to the CPU.\n",
    "# Comment this line out to use a GPU/TPU instead.\n",
    "jax.config.update(\"jax_platform_name\", \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the platform (CPU/GPU) we're using.\n",
    "# `jax.default_backend()` is the public API for this query; the\n",
    "# `jax.lib.xla_bridge` path is internal and deprecated in newer JAX releases.\n",
    "jax.default_backend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def between(a, b):\n",
    "    \"\"\"Return `a.inverse() @ b`, the element that takes `a` to `b`.\n",
    "\n",
    "    Both arguments must expose `inverse()` and support `@` composition; in this\n",
    "    notebook they are `jaxlie.SE3` values.\n",
    "    \"\"\"\n",
    "    a_inverse = a.inverse()\n",
    "    return a_inverse @ b\n",
    "\n",
    "\n",
    "def local_coordinates(a, b):\n",
    "    \"\"\"Return the tangent-space difference `log(a.inverse() @ b)`.\"\"\"\n",
    "    a_T_b = between(a, b)\n",
    "    return a_T_b.log()\n",
    "\n",
    "\n",
    "# https://github.com/brentyi/jaxlie/blob/9f177f2640641c38782ec1dc07709a41ea7713ea/jaxlie/manifold/_manifold_helpers.py\n",
    "def matching_residual(world_T_body, world_t_landmark, body_t_landmark, sigma):\n",
    "    \"\"\"Residual and tangent-space Jacobian of one landmark observation.\n",
    "\n",
    "    The residual is `(world_T_body.apply(body_t_landmark) - world_t_landmark) / sigma`.\n",
    "\n",
    "    Returns:\n",
    "        A `(residual, J)` pair. `J` is the Jacobian of the residual with respect to\n",
    "        the pose storage vector, chained with\n",
    "        `jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)` so that\n",
    "        it is expressed with respect to a tangent-space perturbation of the pose.\n",
    "    \"\"\"\n",
    "\n",
    "    def whitened_error(pose_storage):\n",
    "        predicted = jaxlie.SE3(pose_storage).apply(body_t_landmark)\n",
    "        return (predicted - world_t_landmark) / sigma\n",
    "\n",
    "    pose_storage = world_T_body.parameters()\n",
    "    # TODO: fwd or reverse here?\n",
    "    error_D_storage = jax.jacfwd(whitened_error)(pose_storage)\n",
    "    storage_D_tangent = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)\n",
    "    return whitened_error(pose_storage), error_D_storage @ storage_D_tangent\n",
    "\n",
    "\n",
    "def odometry_residual(world_T_a, world_T_b, a_T_b, diagonal_sigmas):\n",
    "    \"\"\"Residual and tangent-space Jacobian of one relative-pose measurement.\n",
    "\n",
    "    The residual is `local_coordinates(between(world_T_a, world_T_b), a_T_b)`\n",
    "    divided elementwise by `diagonal_sigmas`.\n",
    "\n",
    "    Returns:\n",
    "        A `(residual, J)` pair. `J` is taken with respect to tangent-space\n",
    "        perturbations of both poses, with the columns for `world_T_a` first and\n",
    "        the columns for `world_T_b` second.\n",
    "    \"\"\"\n",
    "\n",
    "    def whitened_error(stacked_storage):\n",
    "        # The first 7 entries are the storage of pose a, the remaining 7 of pose b.\n",
    "        pose_a = jaxlie.SE3(stacked_storage[:7])\n",
    "        pose_b = jaxlie.SE3(stacked_storage[7:])\n",
    "        return local_coordinates(between(pose_a, pose_b), a_T_b) / diagonal_sigmas\n",
    "\n",
    "    stacked_storage = np.concatenate((world_T_a.parameters(), world_T_b.parameters()))\n",
    "    # TODO: fwd or reverse here?\n",
    "    error_D_storage = jax.jacfwd(whitened_error)(stacked_storage)\n",
    "\n",
    "    # Block-diagonal map from the two 6-dim tangent perturbations to the stacked\n",
    "    # 14-dim storage vector.\n",
    "    storage_D_tangent_a = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_a)\n",
    "    storage_D_tangent_b = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_b)\n",
    "    off_diagonal = np.zeros((7, 6))\n",
    "    storage_D_tangent = np.block(\n",
    "        [[storage_D_tangent_a, off_diagonal], [off_diagonal, storage_D_tangent_b]]\n",
    "    )\n",
    "    return whitened_error(stacked_storage), error_D_storage @ storage_D_tangent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://brentyi.github.io/jaxlie/vmap_usage/\n",
    "# Inner map: one pose shared across the leading axis of both landmark arrays.\n",
    "pose_matching_residual = jax.vmap(\n",
    "    matching_residual, in_axes=(None, 0, 0, None), out_axes=(0, 0)\n",
    ")\n",
    "# Outer map: over poses and their body-frame observations; the world-frame\n",
    "# landmarks and sigma are shared by every pose.\n",
    "full_matching_residual = jax.vmap(\n",
    "    pose_matching_residual, in_axes=(0, None, 0, None), out_axes=(0, 0)\n",
    ")\n",
    "# Map over the leading axis of both pose arguments and the measurement; the\n",
    "# sigmas are shared.\n",
    "full_odometry_residual = jax.vmap(odometry_residual, in_axes=(0, 0, 0, None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _linearize_single_problem(\n",
    "    poses,\n",
    "    odometry_relative_pose_measurements,\n",
    "    world_t_landmark,\n",
    "    body_t_landmark,\n",
    "    measurement_sigma,\n",
    "    odometry_diagonal_sigmas,\n",
    "):\n",
    "    \"\"\"Linearize every factor of a single (unbatched) localization problem.\n",
    "\n",
    "    Args:\n",
    "        poses: `jaxlie.SE3` whose `wxyz_xyz` storage has the poses on its leading axis.\n",
    "        odometry_relative_pose_measurements: one measurement per consecutive pose pair.\n",
    "        world_t_landmark: world-frame landmarks, shared by all poses.\n",
    "        body_t_landmark: landmark observations, with the pose on the leading axis.\n",
    "        measurement_sigma: divisor applied to the matching residuals.\n",
    "        odometry_diagonal_sigmas: elementwise divisor applied to the odometry residuals.\n",
    "\n",
    "    Returns:\n",
    "        `(matching_b, matching_J, odometry_b, odometry_J)`: residuals and Jacobians\n",
    "        of the matching factors and of the odometry factors.\n",
    "    \"\"\"\n",
    "    matching_b, matching_J = full_matching_residual(\n",
    "        poses, world_t_landmark, body_t_landmark, measurement_sigma\n",
    "    )\n",
    "    # Odometry factors connect consecutive poses: (0, 1), (1, 2), ...\n",
    "    odometry_b, odometry_J = full_odometry_residual(\n",
    "        jaxlie.SE3(wxyz_xyz=poses.wxyz_xyz[:-1]),\n",
    "        jaxlie.SE3(wxyz_xyz=poses.wxyz_xyz[1:]),\n",
    "        odometry_relative_pose_measurements,\n",
    "        odometry_diagonal_sigmas,\n",
    "    )\n",
    "    return matching_b, matching_J, odometry_b, odometry_J\n",
    "\n",
    "\n",
    "# Batched entry point: maps over a leading batch axis of the first four\n",
    "# arguments and shares the sigmas. The whole batched function is jitted once,\n",
    "# so the per-problem function needs no jit of its own.\n",
    "problem_linearization = jax.jit(\n",
    "    jax.vmap(_linearize_single_problem, in_axes=(0, 0, 0, 0, None, None))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def time_func(f, key, calls):\n",
    "    \"\"\"Return the mean wall-clock seconds per call of `f(key)` over `calls` calls.\n",
    "\n",
    "    A new PRNG key is split off after every call, so `f` receives a different\n",
    "    key each time. JAX dispatches work asynchronously, so every array in the\n",
    "    result of `f` is blocked on before the next call and before the clock is\n",
    "    read; otherwise mostly the dispatch cost would be measured.\n",
    "\n",
    "    Args:\n",
    "        f: callable taking a PRNG key; may return any pytree (or None).\n",
    "        key: initial PRNG key.\n",
    "        calls: number of timed calls; must be positive.\n",
    "    \"\"\"\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(calls):\n",
    "        result = f(key)\n",
    "        # Wait for the computation itself, not just for it to be enqueued.\n",
    "        for leaf in jax.tree_util.tree_leaves(result):\n",
    "            if hasattr(leaf, \"block_until_ready\"):\n",
    "                leaf.block_until_ready()\n",
    "        _, key = jax.random.split(key)\n",
    "    end = time.perf_counter()\n",
    "    return (end - start) / calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(42)\n",
    "\n",
    "# Sweep from the largest batch size down to the smallest.\n",
    "nbatches = reversed([1, 1e1, 1e2, 1e3, 1e4, 1e5])\n",
    "# nposes = reversed([3, 10, 1e2, 1e3])\n",
    "# nlandmarks = reversed([5, 20, 1e2, 1e3, 1e4])\n",
    "nposes = [5]\n",
    "nlandmarks = [20]\n",
    "for BATCH, NUM_POSES, NUM_LANDMARKS in itertools.product(nbatches, nposes, nlandmarks):\n",
    "    # The sweep lists contain float literals such as 1e3; array shapes need ints.\n",
    "    BATCH = int(BATCH)\n",
    "    NUM_POSES = int(NUM_POSES)\n",
    "    NUM_LANDMARKS = int(NUM_LANDMARKS)\n",
    "\n",
    "    # Identity poses in wxyz_xyz storage: w = 1 for every pose of every problem\n",
    "    # and 0 elsewhere. Indexing with `[:, 0]` would instead overwrite all seven\n",
    "    # parameters of the first pose only and leave the other poses with an\n",
    "    # all-zero quaternion.\n",
    "    storage = np.zeros((BATCH, NUM_POSES, 7))\n",
    "    storage = storage.at[..., 0].set(1)\n",
    "    poses = jaxlie.SE3(wxyz_xyz=storage)\n",
    "\n",
    "    # One relative-pose measurement per consecutive pose pair (NUM_POSES - 1 of them).\n",
    "    odometry_relative_pose_measurements = jaxlie.SE3(wxyz_xyz=poses.wxyz_xyz[..., :-1, :])\n",
    "    world_t_landmark = jax.random.normal(key, (BATCH, NUM_LANDMARKS, 3))\n",
    "    _, key = jax.random.split(key)\n",
    "    body_t_landmark = jax.random.normal(key, (BATCH, NUM_POSES, NUM_LANDMARKS, 3))\n",
    "    _, key = jax.random.split(key)\n",
    "    measurement_sigma = 1\n",
    "    odometry_sigmas = np.ones(6)\n",
    "\n",
    "    # Warm-up call so that compilation for these shapes happens outside the timed loop.\n",
    "    problem_linearization(\n",
    "        poses,\n",
    "        odometry_relative_pose_measurements,\n",
    "        world_t_landmark,\n",
    "        body_t_landmark,\n",
    "        measurement_sigma,\n",
    "        odometry_sigmas,\n",
    "    )\n",
    "\n",
    "    def random_linearization(key):\n",
    "        \"\"\"Perturb one landmark with a random value, then linearize the batch.\"\"\"\n",
    "        world_t_landmark_new = world_t_landmark.at[0, 0].set(jax.random.normal(key))\n",
    "        return problem_linearization(\n",
    "            poses,\n",
    "            odometry_relative_pose_measurements,\n",
    "            world_t_landmark_new,\n",
    "            body_t_landmark,\n",
    "            measurement_sigma,\n",
    "            odometry_sigmas,\n",
    "        )\n",
    "\n",
    "    t = time_func(random_linearization, key, 10)\n",
    "\n",
    "    def random_update_only(key):\n",
    "        \"\"\"Baseline: the same landmark perturbation, without the linearization.\"\"\"\n",
    "        world_t_landmark_new = world_t_landmark.at[0, 0].set(jax.random.normal(key))\n",
    "        return world_t_landmark_new\n",
    "\n",
    "    _, key = jax.random.split(key)\n",
    "    t2 = time_func(random_update_only, key, 10)\n",
    "\n",
    "    # Columns: batch size, poses, landmarks, time with linearization, baseline\n",
    "    # time, their difference, and the difference per problem in the batch.\n",
    "    print(\n",
    "        f\"{BATCH:>8}   {NUM_POSES:>8}   {NUM_LANDMARKS:>8}   {t:10.5}   {t2:10.5}   {t - t2:10.5}   {(t - t2)/BATCH:10.5}\"\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "symforce-paper",
   "language": "python",
   "name": "symforce-paper"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
